{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Janus 介绍"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 引言\n",
    "\n",
    "Janus是一个简单、统一且可扩展的多模态理解与生成模型，其将多模态理解与生成的视觉编码进行解耦，缓解了两个任务潜在存在的冲突。可在未来通过拓展，纳入更多的输入模态。Janus-Pro在此基础上，优化训练策略（包括增加训练步数、调整数据配比等）、增加数据（包括使用合成数据等）、扩大模型规模（扩大到70亿参数），使得模型多模态理解和文本到图像指令遵循能力方面取得了进步。\n",
    "\n",
    "Janus包含2个独立的视觉编码路径，分别用于多模态理解、生成，并带来两个收益：1）缓解了源自多模态理解和生成不同粒度需求的冲突，2）具有灵活性和可扩展性，解耦后，理解和生成任务都可以采用针对其领域最先进的编码技术，未来可输入点云、脑电信号或音频数据，使用统一的Transformer进行处理。\n",
    "<div style=\"display: flex; justify-content: center; align-items: center; height: [desired-container-height]px;\">\n",
    "    <img src=\"https://ai-studio-static-online.cdn.bcebos.com/ea0703505b3b40ad923981dbddda20973c81da7a36194e3abc75ad1d9b870ab4\" alt=\"Description\" width=\"50%\" >\n",
    "</div>\n",
    "<center>图1: Janus 架构</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 方法\n",
    "\n",
    "### 2.1 模型架构\n",
    "Janus 的架构如图 1 所示。对于纯文本理解、多模态理解和视觉生成任务，采用独立的编码方法将原始输入转换为特征，然后通过统一的自回归 Transformer 进行处理。具体来说：\n",
    "- 文本理解：我们使用大语言模型（LLM）内置的分词器将文本转换为离散的 ID，并获取每个 ID 对应的特征表示。\n",
    "- 多模态理解：我们使用 SigLIP 编码器从图像中提取高维语义特征。这些特征从 2D 网格展平为 1D 序列，并通过一个理解适配器将这些图像特征映射到 LLM 的输入空间。\n",
    "- 视觉生成：我们使用 VQ 分词器将图像转换为离散的 ID。将 ID 序列展平为 1D 后，使用一个生成适配器将每个 ID 对应的码本嵌入映射到 LLM 的输入空间。\n",
    "然后，我们将这些特征序列连接起来，形成一个多模态特征序列，随后输入到 LLM 中进行处理。在纯文本理解和多模态理解任务中，使用 LLM 内置的预测头进行文本预测；而在视觉生成任务中，使用随机初始化的预测头进行图像预测。整个模型遵循自回归框架，无需特别设计的注意力掩码。\n",
    "\n",
    "### 2.2 Janus 训练\n",
    "Janus的训练分为3个阶段：\n",
    "- 第一阶段：训练Adaptor与Image Head，在嵌入空间创建语言元素与视觉元素之间的联系，使得LLM能够理解图像中的实体，并具备初步视觉生成能力；\n",
    "对于多模态理解，使用来自SHareGPT4V125万个图像-文本配对字幕数据，格式：<图像><文本>；\n",
    "对于视觉生成，使用来自ImageNet1k的120万个样本，格式：<类别名><图像>；\n",
    "\n",
    "- 第二阶段：统一预训练，使用多模态语料库进行统一预训练，学习多模态理解和生成。\n",
    "    - 在该阶段使用纯文本数据、多模态理解数据和视觉生成数据\n",
    "    - 使用ImageNet-1k进行简单的视觉生成训练，随后使用通用文本到图像数据提升模型开放领域的视觉生成能力\n",
    "    - 纯文本数据：DeepSeek-LLM预训练语料库\n",
    "    - 交错的图像-文本数据：WikiHow 和 WIT 数据集；\n",
    "    - 图像Caption数据：来自多个来源的图像，并采用开源多模态模型重新为部分图像添加字幕，数据格式为问答对，如$\\texttt{<caption>}$ Describe the image in detail.$\\texttt{<caption>}$；\n",
    "    - 表格和图表数据：来自 DeepSeek-VL的相应表格和图表数据，数据格式为<question><answer>；\n",
    "    - 视觉生成数据：来自多个数据集的image-caption对以及 200 万个内部数据；在训练过程中，以25%的概率随机仅使用caption的第一句话；ImageNet 样本仅在最初的 120K 训练步骤中出现，其他数据集的图像在后续 60K 步骤中出现；\n",
    "- 第三阶段：监督微调，使用指令微调数据对预训练模型进行微调，以增强其遵循指令和对话的能力。微调除生成编码器之外的所有参数。在监督答案的同时，对系统和用户提示进行遮盖。为了确保Janus在多模态理解和生成方面都具备熟练度，不会针对特定任务分别微调模型。相反，我们使用纯文本对话数据、多模态理解数据和视觉生成数据的混合数据，以确保在各种场景下的多功能性；\n",
    "    - 文本理解：使用来自特定来源的数据；\n",
    "    - 多模态理解：使用来自多个来源的指令调整数据；\n",
    "    - 视觉生成：使用来自部分第二阶段数据集的图像-文本对子集以及 400 万个内部数据；\n",
    "    - 数据格式为：User:$\\texttt{<Input Message>}$ \\n Assistant: $\\texttt{<Response>}$；\n",
    "<div style=\"display: flex; justify-content: center; align-items: center; height: [desired-container-height]px;\">\n",
    "    <img src=\"https://github.com/user-attachments/assets/0035318f-3348-4e5d-9256-a9a3410fa625\" alt=\"Description\" width=\"50%\" >\n",
    "</div>\n",
    "<center>图2: Janus 三阶段训练步骤</center>\n",
    "\n",
    "### 2.3 Janus 推理\n",
    "在推理过程中，Janus 模型采用了一种 Next-token预测的方法。对于纯文本理解和多模态理解，我们遵循从预测分布中顺序采样token的标准做法。对于图像生成，我们利用了无分类器引导（CFG）在训练过程中，我们以10%的概率将文本到图像数据中的文本条件替换为填充token，使模型具备无条件视觉生成能力。对于生成下一个token的概率分布 $l_g$ 的计算公式为 $l_g = l_u + s(l_c - l_u)$ ，$l_c$是条件概率分布，$l_u$ 是条件概率分布，$s$ 是CFG系数，默认情况下 $s$ 为 5.\n",
    "\n",
    "### 2.4 Janus-Pro\n",
    "- 训练策略\n",
    "    - Stage 1: 增加训练步数，在 ImageNet 上充分训练；\n",
    "    - Stage 2: 不再使用 ImageNet，直接使用常规文本到图像数据的训练数据；\n",
    "    - Stage 3: 修改微调过程中的数据集配比，将多模态数据、纯文本数据和文本到图像的比例从 7:3:10 改为 5:1:4；\n",
    "- 数据规模\n",
    "    - 多模态理解\n",
    "        - Stage 2: 增加 9000 万个样本，包括图像字幕数据 YFCC、表格图表文档理解数据 Doc-matrix；\n",
    "        - Stage 3: 加入 DeepSeek-VL2 额外数据集，如 MEME 理解等；\n",
    "    - 视觉生成：真实世界数据可能包含质量不高，导致文本到图像的生成不稳定，产生美学效果不佳的输出，Janus-Pro 使用 7200 万份合成美学数据样本，统一预训练阶段（Stage 2）真实数据与合成数据比例 1:1；\n",
    "- 模型规模\n",
    "    - 将模型参数扩展到 70 亿参数规模；"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 代码解读\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 Janus 组网代码介绍"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 类名: JanusMultiModalityCausalLM\n",
    "- 功能: 该类实现了一个多模态因果语言模型，它能够处理图像和文本数据\n",
    "- 实现步骤: \n",
    "    - 初始化：\n",
    "        - 从配置对象中提取各个组件的配置。\n",
    "        - 使用model_name_to_cls函数和配置参数来实例化视觉模型、对齐器、生成视觉模型、生成对齐器、生成头和语言模型。\n",
    "        - 创建一个Embedding层用于图像标识符到Embedding向量的映射。\n",
    "    - 准备输入Embed：\n",
    "        - 重新排列图像数据pixel_values的形状以适应视觉模型的输入要求。\n",
    "        - 使用视觉模型处理图像数据，并通过对齐器生成图像Embed。\n",
    "        - 重新排列图像Embed和掩码的形状以匹配文本输入的形状。\n",
    "        - 处理文本输入input_ids，将其转换为语言模型可以处理的Embed形式。\n",
    "        - 根据掩码将图像嵌入插入到文本嵌入中，生成最终的输入Embed input_embeds。\n",
    "    - 准备生成图像嵌入：\n",
    "        - 使用嵌入层将图像标识符image_ids映射到 Embedding 向量。\n",
    "        - 通过对齐器处理这些 Embedding 向量，生成最终的图像Embedding。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class JanusMultiModalityCausalLM(JanusMultiModalityPreTrainedModel):\n",
    "    config_class = MultiModalityConfig\n",
    "\n",
    "    def __init__(self, config: MultiModalityConfig):\n",
    "        super().__init__(config)\n",
    "        vision_config = config.vision_config\n",
    "        vision_cls = model_name_to_cls(vision_config.cls)\n",
    "        self.vision_model = vision_cls(**vision_config.params)\n",
    "        aligner_config = config.aligner_config\n",
    "        aligner_cls = model_name_to_cls(aligner_config.cls)\n",
    "        self.aligner = aligner_cls(aligner_config.params)\n",
    "        gen_vision_config = config.gen_vision_config\n",
    "        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)\n",
    "        self.gen_vision_model = gen_vision_cls()\n",
    "        gen_aligner_config = config.gen_aligner_config\n",
    "        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)\n",
    "        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)\n",
    "        gen_head_config = config.gen_head_config\n",
    "        gen_head_cls = model_name_to_cls(gen_head_config.cls)\n",
    "        self.gen_head = gen_head_cls(gen_head_config.params)\n",
    "        self.gen_embed = paddle.nn.Embedding(\n",
    "            num_embeddings=gen_vision_config.params[\"image_token_size\"],\n",
    "            embedding_dim=gen_vision_config.params[\"n_embed\"],\n",
    "        )\n",
    "        language_config = config.language_config\n",
    "        self.language_model = LlamaForCausalLM(language_config)\n",
    "\n",
    "    def prepare_inputs_embeds(\n",
    "        self,\n",
    "        input_ids: paddle.Tensor,\n",
    "        pixel_values: paddle.Tensor,\n",
    "        images_seq_mask: paddle.Tensor,\n",
    "        images_emb_mask: paddle.Tensor,\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"\"\"\n",
    "\n",
    "        Args:\n",
    "            input_ids (paddle.Tensor): [b, T]\n",
    "            pixel_values (paddle.Tensor):   [b, n_images, 3, h, w]\n",
    "            images_seq_mask (paddle.Tensor): [b, T]\n",
    "            images_emb_mask (paddle.Tensor): [b, n_images, n_image_tokens]\n",
    "\n",
    "            assert paddle.sum(images_seq_mask) == paddle.sum(images_emb_mask)\n",
    "\n",
    "        Returns:\n",
    "            input_embeds (paddle.Tensor): [b, T, D]\n",
    "        \"\"\"\n",
    "        bs, n = tuple(pixel_values.shape)[0:2]\n",
    "        images = rearrange(pixel_values, \"b n c h w -> (b n) c h w\")\n",
    "        images_embeds = self.aligner(self.vision_model(images))\n",
    "        images_embeds = rearrange(images_embeds, \"(b n) t d -> b (n t) d\", b=bs, n=n)\n",
    "        images_emb_mask = rearrange(images_emb_mask, \"b n t -> b (n t)\")\n",
    "        input_ids[input_ids < 0] = 0\n",
    "        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)\n",
    "        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]\n",
    "\n",
    "        return inputs_embeds\n",
    "\n",
    "    def prepare_gen_img_embeds(self, image_ids: paddle.Tensor):\n",
    "        return self.gen_aligner(self.gen_embed(image_ids))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Janus 多模态生成代码介绍"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 方法: generate\n",
    "- 参数:\n",
    "    - mmgpt：JanusMultiModalityCausalLM的对象，负责生成图像和文本。\n",
    "    - vl_chat_processor：一个处理器对象，用于处理视觉-语言（VL）聊天数据，包括分词和图像编码等。\n",
    "    - prompt：一个字符串，代表输入给模型的文本提示。\n",
    "    - temperature：一个浮点数，用于调整生成结果的随机性（或称为“温度”）。较低的值会使生成结果更加确定，而较高的值会增加多样性。\n",
    "    - parallel_size：一个整数，表示并行生成图像的数量。\n",
    "    - cfg_weight：一个浮点数，用于在生成过程中调整条件和无条件生成的概率分布（logits）之间的权重。\n",
    "    - image_token_num_per_image：一个整数，表示每张图像生成的token数量。\n",
    "    - img_size：一个整数，表示生成图像的尺寸（假设图像是正方形）。\n",
    "    - patch_size：一个整数，表示图像被分割成的小块（patch）的尺寸\n",
    "- 步骤:\n",
    "    - 文本处理：使用vl_chat_processor的分词器将文本提示编码为输入ID，然后转换为Paddle张量。\n",
    "    - 初始化token：创建一个用于存储输入token和生成图像token的张量。对于并行生成的每个样本，都复制输入token，并在奇数索引的样本中插入填充token。\n",
    "    - 输入Embedding：将token转换为模型可以理解的Embedding形式。\n",
    "    - 生成图像token：通过一个循环，逐步生成图像的每个令牌。在每个步骤中：\n",
    "        - 更新position id 以反映当前token生成的位置序号。\n",
    "        - 使用模型的语言模型部分生成下一个token的概率分布。\n",
    "        - 根据条件和无条件生成的 logits 以及温度调整概率分布。\n",
    "        - 使用paddle.multinomial根据调整后的概率分布采样下一个令牌。\n",
    "        - 使用生成的token生成图像Embedding，并更新输入Embedding以用于下一次迭代。\n",
    "    - 解码图像：将生成的图像token解码为图像数据。\n",
    "    - 后处理和保存：将解码后的图像数据标准化为0-255之间的整数，并保存为JPEG文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(\n",
    "    mmgpt,\n",
    "    vl_chat_processor,\n",
    "    prompt: str,\n",
    "    temperature: float = 1,\n",
    "    parallel_size: int = 2,\n",
    "    cfg_weight: float = 5,\n",
    "    image_token_num_per_image: int = 576,\n",
    "    img_size: int = 384,\n",
    "    patch_size: int = 16,\n",
    "):\n",
    "    input_ids = vl_chat_processor.tokenizer.encode(prompt)\n",
    "    input_ids = paddle.to_tensor(data=input_ids.input_ids, dtype=\"int64\")\n",
    "    tokens = paddle.zeros(shape=(parallel_size * 2, len(input_ids)), dtype=\"int32\")\n",
    "    for i in range(parallel_size * 2):\n",
    "        tokens[i, :] = input_ids\n",
    "        if i % 2 != 0:\n",
    "            tokens[i, 1:-1] = vl_chat_processor.pad_id\n",
    "    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)  # [4, 50, 2048]\n",
    "    generated_tokens = paddle.zeros(shape=(parallel_size, image_token_num_per_image), dtype=\"int32\")\n",
    "    batch_size, seq_length = inputs_embeds.shape[:2]\n",
    "    for i in tqdm(range(image_token_num_per_image)):\n",
    "        batch_size, seq_length = inputs_embeds.shape[:2]\n",
    "\n",
    "        past_key_values_length = outputs.past_key_values[0][0].shape[1] if i != 0 else 0\n",
    "        position_ids = paddle.arange(past_key_values_length, seq_length + past_key_values_length).expand(\n",
    "            (batch_size, seq_length)\n",
    "        )\n",
    "\n",
    "        outputs = mmgpt.language_model.llama(\n",
    "            position_ids=position_ids,\n",
    "            inputs_embeds=inputs_embeds,  # [4, 1, 2048]\n",
    "            use_cache=True,\n",
    "            past_key_values=outputs.past_key_values if i != 0 else None,\n",
    "            return_dict=True,\n",
    "        )\n",
    "\n",
    "        hidden_states = outputs.last_hidden_state\n",
    "        logits = mmgpt.gen_head(hidden_states[:, -1, :])\n",
    "        logit_cond = logits[0::2, :]\n",
    "        logit_uncond = logits[1::2, :]\n",
    "\n",
    "        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)\n",
    "        probs = paddle.nn.functional.softmax(x=logits / temperature, axis=-1)\n",
    "        next_token = paddle.multinomial(x=probs, num_samples=1)\n",
    "\n",
    "        generated_tokens[:, i] = next_token.squeeze(axis=-1)\n",
    "        next_token = paddle.concat(x=[next_token.unsqueeze(axis=1), next_token.unsqueeze(axis=1)], axis=1).reshape(\n",
    "            [-1]\n",
    "        )\n",
    "        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)\n",
    "        inputs_embeds = img_embeds.unsqueeze(axis=1)\n",
    "\n",
    "    dec = mmgpt.gen_vision_model.decode_code(\n",
    "        generated_tokens.to(dtype=\"int32\"), shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]\n",
    "    )\n",
    "    dec = dec.to(\"float32\").cpu().numpy().transpose(0, 2, 3, 1)\n",
    "    dec = np.clip((dec + 1) / 2 * 255, 0, 255)\n",
    "    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)\n",
    "    visual_img[:, :, :] = dec\n",
    "    os.makedirs(\"janus_generated_samples\", exist_ok=True)\n",
    "    for i in range(parallel_size):\n",
    "        save_path = os.path.join(\"janus_generated_samples\", \"img_{}.jpg\".format(i))\n",
    "        PIL.Image.fromarray(visual_img[i]).save(save_path)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
